{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/02-negative-mining.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/02-negative-mining.ipynb)\n",
    "\n",
    "To perform the negative mining step we must create a vector database to store encoded passages, and allow us to search for similar passages that do not match the query we're searching with. This requires two things:\n",
    "\n",
    "* a pre-existing retriever model to build encodings - for this we will use a model from the *sentence-transformers* library\n",
    "* a vector DB to store encodings - for this we will use Pinecone as it is an free and easy vector DB to setup, which is fast at scale\n",
    "\n",
    "Let's load the model first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SentenceTransformer(\n",
       "  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
       "  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       ")"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
    "model.max_seq_length = 256\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing the Index\n",
    "\n",
    "Now we need a place to store these embeddings and enable a efficient vector search through them all. To do that we use Pinecone, we can get a [free API key](https://app.pinecone.io/) and enter it below where we will initialize our connection to Pinecone and create a new index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pinecone import Pinecone\n",
    "\n",
    "# initialize connection to pinecone (get API key at app.pinecone.io)\n",
    "api_key = os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'\n",
    "\n",
    "# configure client\n",
    "pc = Pinecone(api_key=api_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we setup our index specification, this allows us to define the cloud provider and region where we want to deploy our index. You can find a list of all [available providers and regions here](https://docs.pinecone.io/docs/projects)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pinecone import ServerlessSpec\n",
    "\n",
    "cloud = os.environ.get('PINECONE_CLOUD') or 'aws'\n",
    "region = os.environ.get('PINECONE_REGION') or 'us-east-1'\n",
    "\n",
    "spec = ServerlessSpec(cloud=cloud, region=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_name = \"negative-mining\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# check if index already exists (it shouldn't if this is first time)\n",
    "if index_name not in pc.list_indexes().names():\n",
    "    # if does not exist, create index\n",
    "    pc.create_index(\n",
    "        index_name,\n",
    "        dimension=model.get_sentence_embedding_dimension(),\n",
    "        metric='dotproduct',\n",
    "        spec=spec\n",
    "    )\n",
    "    # wait for index to be initialized\n",
    "    while not pc.describe_index(index_name).status['ready']:\n",
    "        time.sleep(1)\n",
    "\n",
    "# connect to index\n",
    "index = pc.Index(index_name)\n",
    "# view index stats\n",
    "index.describe_index_stats()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we encode the passages and store in the `negative-mining` index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def get_text():\n",
    "    with open('data/pairs.tsv', 'r', encoding='utf-8') as fp:\n",
    "        lines = fp.read().split('\\n')\n",
    "    for line in tqdm(lines):\n",
    "        try:\n",
    "            query, passage = line.split('\\t')\n",
    "            yield query, passage\n",
    "        except ValueError:\n",
    "            # in case of malformed data, pass onto next row\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f123d57309b042eca4ce279ec0aff06e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'dimension': 768,\n",
       " 'index_fullness': 0.0,\n",
       " 'namespaces': {'': {'vector_count': 67840}}}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pair_gen = get_text()\n",
    "\n",
    "pairs = []\n",
    "to_upsert = []\n",
    "passage_batch = []\n",
    "id_batch = []\n",
    "batch_size = 64\n",
    "\n",
    "for i, (query, passage) in enumerate(pair_gen):\n",
    "    pairs.append((query, passage))\n",
    "    # we do this to avoid passage duplication in the vector DB\n",
    "    if passage not in passage_batch: \n",
    "        passage_batch.append(passage)\n",
    "        id_batch.append(str(i))\n",
    "    # on reaching batch_size, we encode and upsert\n",
    "    if len(passage_batch) == batch_size:\n",
    "        embeds = model.encode(passage_batch).tolist()\n",
    "        # upload to index\n",
    "        index.upsert(vectors=list(zip(id_batch, embeds)))\n",
    "        # refresh batches\n",
    "        passage_batch = []\n",
    "        id_batch = []\n",
    "        \n",
    "# check number of vectors in the index\n",
    "index.describe_index_stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The database is setup for us to begin the *negative mining* step. We will loop through each query in `pairs`, returning *10* of the most similar passage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33e05700fb8a43daadf39f5c2f2166d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "batch_size = 100\n",
    "triplets = []\n",
    "\n",
    "for i in tqdm(range(0, len(pairs), batch_size)):\n",
    "    # embed queries and query pinecone in batches to minimize network latency\n",
    "    i_end = min(i+batch_size, len(pairs))\n",
    "    queries = [pair[0] for pair in pairs[i:i_end]]\n",
    "    pos_passages = [pair[1] for pair in pairs[i:i_end]]\n",
    "    # create query embeddings\n",
    "    query_embs = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)\n",
    "    # search for top_k most similar passages\n",
    "    res = index.query(vector=query_embs.tolist(), top_k=10)\n",
    "    # iterate through queries and find negatives\n",
    "    for query, pos_passage, query_res in zip(queries, pos_passages, res['results']):\n",
    "        top_results = query_res['matches']\n",
    "        # shuffle results so they are in random order\n",
    "        random.shuffle(top_results)\n",
    "        for hit in top_results:\n",
    "            neg_passage = pairs[int(hit['id'])][1]\n",
    "            # check that we're not just returning the positive passage\n",
    "            if neg_passage != pos_passage:\n",
    "                # if not we can add this to our (Q, P+, P-) triplets\n",
    "                triplets.append(query+'\\t'+pos_passage+'\\t'+neg_passage)\n",
    "                break\n",
    "\n",
    "with open('data/triplets.tsv', 'w', encoding='utf-8') as fp:\n",
    "    fp.write('\\n'.join(triplets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "pc.delete_index(index_name)  # delete the index when done to avoid higher charges (if using multiple pods)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that we now have even more *(query, passage) pairs*, that are both positive and negative matches. The next step in GPL will see us scoring all of these pairs using a cross-encoder model."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
